[Neuron] Add tensor parallel support for Neuron backend #13718
JingyaHuang wants to merge 54 commits into
… into add-neuron-backend
… into add-neuron-backend
  # flash-attn only returns LSE if dropout_p > 0. So, we need to workaround.
- if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
+ if grad_enabled or (_parallel_config is not None and _parallel_config._cp_world_size > 1):
With TP, context_parallel_config can be None, so we add _parallel_config._cp_world_size to cover that case.
…rs (huggingface#13946)
SkyReels-V2 and ChronoEdit are both built on Wan, and their transformers have the same keys as WanTransformer3DModel, so they reuse convert_wan_transformer_to_diffusers (like WanVACE / WanAnimate). This lets the community GGUF builds load directly.
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
fix(cosmos3): pin VAE latent norm buffers to encode output device
Under sharded placement (device_map="balanced"), vae.encode() runs on the VAE's own device while the mean/inv_std buffers were pinned to x.device, causing a cross-device RuntimeError. Compute raw_mu first, then pin the normalization buffers to its device so all tensors share one device.
Co-authored-by: Atharva Joshi <atjoshi@smc521ge-0036.ipp2a2.colossus.nvidia.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
…13876)
* docs: fix repeated word typo in set_timesteps docstring
  Removed the duplicate word "schedule" from the docstring for the sigmas argument in EulerDiscreteScheduler.set_timesteps.
* Update scheduling_euler_discrete.py
* Apply style fixes
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
sayakpaul left a comment
Thanks for working on this.
There are a lot of intrusive, model-specific changes, which I think is a bit of an anti-pattern. It's probably also a consequence of some of the fusion happening inside Flux2.
More specifically, the intrusive pieces exist for one reason: Flux2 fuses projections into single Linears (SwiGLU gate+linear, and to_qkv_mlp_proj packing Q/K/V and MLP).
Contiguous column sharding is blind to that internal layout, so:
- you must reorder rows so each rank gets paired slices -> the permuters, and
- the local tensor width no longer factors as heads × head_dim or splits cleanly into qkv/mlp -> the runtime local_* recomputation.
I opened JingyaHuang#1 to simplify some of the stuff. LMK.
Furthermore, would the changes related to fusing be the same for Flux1, for example? I think if the layers were unfused, parallelize_module + DTensor would handle head-splitting automatically and none of this would be needed.
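To make the fused-projection issue concrete, here is a minimal pure-Python sketch of the row bookkeeping only. It is not the PR's implementation: `contiguous_shard`, `packed_shard`, and the toy row labels are hypothetical names chosen for illustration, and the 4+4 layout is an assumed stand-in for a fused gate+linear projection.

```python
def contiguous_shard(rows, world_size, rank):
    """Plain column-parallel slicing: one contiguous chunk of output rows per rank."""
    per_rank = len(rows) // world_size
    return rows[rank * per_rank:(rank + 1) * per_rank]


def packed_shard(rows, block_sizes, world_size, rank):
    """Slice each fused block separately, then concatenate the per-block slices."""
    out, start = [], 0
    for size in block_sizes:
        block = rows[start:start + size]
        out.extend(contiguous_shard(block, world_size, rank))
        start += size
    return out


# Hypothetical fused projection: 4 "gate" rows followed by 4 "linear" rows.
rows = [f"gate{i}" for i in range(4)] + [f"lin{i}" for i in range(4)]

# Contiguous sharding ignores the internal layout of the fused weight.
naive = [contiguous_shard(rows, world_size=2, rank=r) for r in range(2)]

# Block-by-block sharding gives every rank a matching gate/linear pair.
packed = [packed_shard(rows, block_sizes=[4, 4], world_size=2, rank=r) for r in range(2)]

print(naive)
print(packed)
```

With contiguous slicing, rank 0 holds only gate rows and rank 1 only linear rows, so neither rank can compute its share of the gated activation locally. Slicing block by block keeps the paired slices on the same rank, which is the property the permuters were emulating.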
config: TensorParallelConfig,
tp_plan: dict,
*,
backend: str = "default",
Can this not be derived from torch_device?
    return _get_projections(attn, hidden_states, encoder_hidden_states)


def _get_tp_degree(parallel_config) -> int:
Seems like it should be present in _modeling_parallel.py?
@property
def _cp_world_size(self) -> int:
    """Context-parallel world size, or 1 when context parallelism is not enabled.

    Lets attention backends branch on context parallelism without dereferencing a possibly ``None``
    ``context_parallel_config`` (e.g. when only tensor parallelism is active).
    """
    cp = self.context_parallel_config
    if cp is None or cp._world_size is None:
        return 1
    return cp._world_size
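The behavior of this property can be checked in isolation with a small sketch. The two dataclasses below are simplified stand-ins, not the real diffusers config classes, and `needs_lse` is a hypothetical helper that only mirrors the guarded condition from the attention backend.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ContextParallelConfig:
    # Stand-in for the real config; only the field used here is modeled.
    _world_size: Optional[int] = None


@dataclass
class ParallelConfig:
    # Stand-in container; None means context parallelism is not enabled.
    context_parallel_config: Optional[ContextParallelConfig] = None

    @property
    def _cp_world_size(self) -> int:
        cp = self.context_parallel_config
        if cp is None or cp._world_size is None:
            return 1
        return cp._world_size


def needs_lse(grad_enabled: bool, parallel_config: Optional[ParallelConfig]) -> bool:
    """Hypothetical mirror of the guarded condition in the attention backend."""
    return grad_enabled or (parallel_config is not None and parallel_config._cp_world_size > 1)


# TP-only setup: no context-parallel config at all.
tp_only = ParallelConfig()
# Context parallelism across 4 ranks.
cp_enabled = ParallelConfig(ContextParallelConfig(_world_size=4))

print(tp_only._cp_world_size, cp_enabled._cp_world_size)
```

In the TP-only case the old expression would have raised an AttributeError on `None._world_size`; going through the property, the condition simply evaluates to false.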
# On Neuron, run the index-heavy `_unpack_latents_with_ids` on CPU to avoid expensive
# device<->host syncs from the gather/scatter arithmetic, then move the result back.
latent_device = latents.device
on_neuron = get_device() == "neuron"
if on_neuron:
    latents = latents.cpu()
    latent_ids = latent_ids.cpu()
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
…arallel
Adopts Sayak's changes from #1 that replace the Flux2-specific _tp_fused_block_permuters (permute-then-slice) with generic PackedColwiseParallel / PackedRowwiseParallel styles that slice fused projections block-by-block. Also drops the now-unused _tp_fused_block_permuters base-class default in modeling_utils.
Keeps torch.chunk in Flux2SwiGLU.forward (TorchAO compile regression fix), overriding the half-slicing on Sayak's branch.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…into support-neuron-tp
@JingyaHuang did my PR break any neuron-specific stuff?
What does this PR do?
Adds tensor-parallel (TP) inference for diffusers models on AWS Neuron (Trainium/Inferentia) devices. As suggested, Flux2 Klein is the starting point, but the TP support is generic, easy to extend to other backends (CUDA, TPU, and more), and exposed through the existing public API used for CP:
model.enable_parallelism(config=TensorParallelConfig(...)).
Key changes:
- apply_tensor_parallel that shards from a flat _tp_plan.
Quick test: Flux2 TP on Neuron (for future release)
Run with
torchrun --nproc_per_node=8 flux2_tp8_neuron.py
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.